import torch

# Demonstrate torch.cat: the two tensors differ only in dimension 2, which is
# the dimension they are concatenated along. Each tensor is drawn immediately
# after reseeding the RNG with the same seed, so the script's output is
# reproducible from run to run.
SEED = 10


def _seeded_randint(shape):
    """Reseed the RNG with SEED, then return random integers in [0, 10) of *shape*."""
    torch.random.manual_seed(SEED)
    return torch.randint(0, 10, shape)


data1 = _seeded_randint((3, 4, 5, 6))
data2 = _seeded_randint((3, 4, 3, 6))

# Sizes along dim 2 add up (5 + 3 = 8); all other dimensions must already match.
merged = torch.cat((data1, data2), dim=2)
print(merged.shape)
